热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

尺寸|更多_VGGpytorch实现

篇首语:本文由编程笔记#小编为大家整理,主要介绍了VGG-pytorch实现相关的知识,希望对你有一定的参考价值。 VGG 1.网络结构 如图可见,VGG网络的构造很简单&#xff0

篇首语:本文由编程笔记#小编为大家整理,主要介绍了VGG-pytorch实现相关的知识,希望对你有一定的参考价值。



VGG

1.网络结构

如图可见,VGG网络的构造很简单,通过不断地卷积,池化,扩大通道数,降低宽高,最终平展为一维数据再进行softmax分类。相较于AlexNet而言,VGG最大的特征就是降低了卷积核尺寸,增加了卷积核的深度层数,拥有更多的非线性变换,增加了CNN对特征的学习能力。


2.pytorch网络设计

这里采用的数据集为FashionMNIST数据集,慢慢地往后的文章也会引入更多的数据集使用,Fashion MNIST包含了10种类别70000个不同时尚穿戴品的图像,整体数据结构上跟MNIST完全一致。每张图像的尺寸同样是28*28,但下载下来的数据通道数为1。

#定义块
def vgg_block(num_convs, in_channels, num_channels):
layers = []
for i in range(num_convs):
layers += [nn.Conv2d(in_channels=in_channels, out_channels=num_channels, kernel_size=3, padding=1)]
in_channels = num_channels
layers += [nn.ReLU()]
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
return nn.Sequential(*layers)
# 网络定义
class VGG(nn.Module):
def __init__(self):
super(VGG, self).__init__()
# 这里适配输入为3x224x224的图片
self.conv_arch = ((1, 3, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))
self.conv_arch_28x28 =((2, 256, 512), (2, 512, 512))
# 这里为了适配1x28x28的输入图片大小,对原始网络层做些修改
#前四层不做池化,保留原始特征
self.conv_28x28=nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
self.conv_28x28_2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
self.conv_28x28_3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
self.conv_28x28_4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
#后4层使用VGG块构造
layers = []
for (num_convs, in_channels, num_channels) in self.conv_arch_28x28:
layers += [vgg_block(num_convs, in_channels, num_channels)]
self.features = nn.Sequential(*layers)
self.Linear = nn.Linear(512 * 7 * 7, 4096)
self.drop1 = nn.Dropout(0.5)
self.Linear2 = nn.Linear(4096, 4096)
self.drop2 = nn.Dropout(0.5)
self.Linear3 = nn.Linear(4096, 10)
def forward(self, x):
x=F.relu(self.conv_28x28(x))
x = F.relu(self.conv_28x28_2(x))
x = F.relu(self.conv_28x28_3(x))
x = F.relu(self.conv_28x28_4(x))
x = self.features(x)
x = x.view(-1, 512 * 7 * 7)
x = self.Linear3(self.drop2(F.relu(self.Linear2(self.drop1(F.relu(self.Linear(x)))))))
return x

3.网络测试


1.数据集读取分类

# 数据增强
draw = draw_tool.draw_tool()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomGrayscale(),
transforms.ToTensor()])
# 验证集不增强
transform1 = transforms.Compose(
[
transforms.ToTensor()])
train_set = torchvision.datasets.FashionMNIST(root='F:\\\\pycharm\\\\dataset', train=True,
download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=30,
shuffle=True, num_workers=2)
test_set = torchvision.datasets.FashionMNIST(root='F:\\\\pycharm\\\\dataset', train=False,
download=True, transform=transform1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=30,
shuffle=False, num_workers=2)

2.模型训练设置

model = VGG()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
model = model.to(device)

3.训练

if __name__ == '__main__':
for epoch in range(4):
train(epoch)
torch.save(model.state_dict(), "minist_last.pth")
draw.show()

训练部分,可能是由于网络太大,或者是数据集太多的缘故,跑得非常慢,所以这里只针对整个数据集进行了4个epoch训练,训练测试结果如下:

最后一次训练的精度达到了86.77%,但明显可以看出还可以继续增加。


4.总结

​ VGG16相比AlexNet的一个改进是采用连续的几个3x3的卷积核代替AlexNet中的较大卷积核(11x11,7x7,5x5)。对于给定的感受野,采用堆积的小卷积核是优于采用大的卷积核,因为多层非线性层可以增加网络深度来保证学习更复杂的模式,而且代价还比较小(参数更少)。

​ 在VGG中,使用了3个3x3卷积核来代替7x7卷积核,使用了2个3x3卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,提升了网络的深度,在一定程度上提升了神经网络的效果。这点我认为应该是把卷积宽高改革为卷积层数,能更好地去调整参数。

​ 使用3x3卷积核的好处:减少了总体传入显卡的参数,且有利于保护图像的原始性质。
最后非常希望有一样的初学者或者大佬能多评论留言,一起分享一下过程和经历,感激不尽。


5.补充

最近学着用tensorboard,又跑了一遍,记录一下效果。


推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 很多时候在注册一些比较重要的帐号,或者使用一些比较重要的接口的时候,需要使用到随机字符串,为了方便,我们设计这个脚本需要注意 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • Nginx使用(server参数配置)
    本文介绍了Nginx的使用,重点讲解了server参数配置,包括端口号、主机名、根目录等内容。同时,还介绍了Nginx的反向代理功能。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • C# 7.0 新特性:基于Tuple的“多”返回值方法
    本文介绍了C# 7.0中基于Tuple的“多”返回值方法的使用。通过对C# 6.0及更早版本的做法进行回顾,提出了问题:如何使一个方法可返回多个返回值。然后详细介绍了C# 7.0中使用Tuple的写法,并给出了示例代码。最后,总结了该新特性的优点。 ... [详细]
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • 突破MIUI14限制,自定义胶囊图标、大图标样式,支持任意APP
    本文介绍了如何突破MIUI14的限制,实现自定义胶囊图标和大图标样式,并支持任意APP。需要一定的动手能力和主题设计师账号权限或者会主题pojie。详细步骤包括应用包名获取、素材制作和封包获取等。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • Android工程师面试准备及设计模式使用场景
    本文介绍了Android工程师面试准备的经验,包括面试流程和重点准备内容。同时,还介绍了建造者模式的使用场景,以及在Android开发中的具体应用。 ... [详细]
  • 本文介绍了操作系统的定义和功能,包括操作系统的本质、用户界面以及系统调用的分类。同时还介绍了进程和线程的区别,包括进程和线程的定义和作用。 ... [详细]
  • STL迭代器的种类及其功能介绍
    本文介绍了标准模板库(STL)定义的五种迭代器的种类和功能。通过图表展示了这几种迭代器之间的关系,并详细描述了各个迭代器的功能和使用方法。其中,输入迭代器用于从容器中读取元素,输出迭代器用于向容器中写入元素,正向迭代器是输入迭代器和输出迭代器的组合。本文的目的是帮助读者更好地理解STL迭代器的使用方法和特点。 ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • [翻译]PyCairo指南裁剪和masking
    裁剪和masking在PyCairo指南的这个部分,我么将讨论裁剪和masking操作。裁剪裁剪就是将图形的绘制限定在一定的区域内。这样做有一些效率的因素࿰ ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了htmlJS相关的知识,希望对你有一定的参考价值。 ... [详细]
author-avatar
手机用户2602919063
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有